from copy import deepcopy
from threading import Thread
import os
import os.path
import time
import pickle

from ma.commons.core.taskinprogress import *
from ma.commons.core.reducetipinfo import ReduceTIPInfo
from ma.commons.core.reducetaskrunner import ReduceTaskRunner
from ma.commons.core.processrunner import ProcessRunner
from ma.commons.core.abstractreducer import AbstractReducer
from ma.utils import location
import ma.commons.core.inputfetchernoquery as inputfetchernoquery
import ma.commons.core.constants as constants
import ma.log
import ma.const



class ReduceTIP(TaskInProgress, Thread):
    """A TaskInProgress specialization for reduce tasks.

    Runs as a thread and drives the full reduce-task lifecycle:
    shuffle the map/previous-reduce outputs locally (localizeTask),
    launch the user reducer in a child process (launchTask), log
    estimation and threshold data, and report completion back to the
    nodetracker (taskComplete).
    """

    def __init__(self, nodetracker, reducetip_info=None, reduce_level=0, struct_id=constants.STRUCT_ID_DEFAULT):
        """Arguments:
        nodetracker -- the node tracker this task will run on.
        reducetip_info -- an optional prefabricated ReduceTIPInfo object;
            it is deep-copied so the caller's instance is never mutated.
        reduce_level -- the reduce-tree level this task belongs to.
        struct_id -- the structural tag id of the associated structure.
        """

        self.__log = ma.log.get_logger('ma.mrimprov')

        TaskInProgress.__init__(self, nodetracker)
        Thread.__init__(self)

        # Overshadow the inherited TaskInProgressInfo object with a
        # ReduceTIPInfo object.
        if reducetip_info is None:
            # A ReduceTIPInfo with default settings and no affiliation
            # to any job or task.
            self.tip_info = ReduceTIPInfo(-1, -1, None)
        else:
            self.tip_info = deepcopy(reducetip_info)

        self.reduce_level = reduce_level
        self.struct_id = struct_id

        # Populate all the keys in the ReduceTIPInfo object (None means
        # the map's key set is unknown / unbounded).
        self.tip_info.list_of_all_keys = self.get_key_set(self.tip_info.job_id)

        user_process_id = self.tip_info.returnTaskID()

        # The task runner object that will launch the user-provided
        # reduce program.
        user_reduce_module = ma.const.JobsXmlData.get_str_data(ma.const.xml_reduce_module_name, self.tip_info.job_id)
        user_reduce_class = ma.const.JobsXmlData.get_str_data(ma.const.xml_reduce_class_name, self.tip_info.job_id)
        user_runner_program = location.get_src_ma_path() + ma.const.JobsXmlData.get_filepath_str_data(ma.const.xml_mrplus_redtask_runner, self.tip_info.job_id)

        # Get the filename which holds the reduce compute misc. output
        # and hand the tip_info over to the child process via pickle.
        proc_temp_filename = ma.const.JobsXmlData.get_str_data(ma.const.xml_reduce_proc_temp_filename, self.tip_info.job_id, self.tip_info.task_id)
        ProcessRunner.store_pickled_data([self.tip_info], proc_temp_filename, self.tip_info.job_id)
        self.process_runner = ProcessRunner(user_runner_program, [user_reduce_module, user_reduce_class, self.tip_info.job_id, self.tip_info.task_id], user_process_id)

        # Keys that were mapped but never encountered by this reduce.
        self.list_of_not_seen_keys = []
        # Keys that were encountered but did not change at this reduce.
        self.list_of_not_changed_keys = []
        # Keys that crossed the (user-xml supplied) threshold.
        self.list_of_thresholded_keys = []

        # Find out whether the user reducer is estimatable, and hence
        # whether estimate logging is needed.  getattr is used instead
        # of eval() to instantiate the class by name.
        temp_import = __import__(user_reduce_module, fromlist=[user_reduce_class])
        temp_reduce = getattr(temp_import, user_reduce_class)(None, None, None, [[], [], -1, -1])
        self.estimatable = (temp_reduce.estimatable == AbstractReducer.ESTIMATOR)


    def get_key_set(self, job_id):
        """Return the full key set produced by the user's map class for
        *job_id*, or None if the user cannot enumerate the keys (in
        which case the reduce cannot be estimated).
        """

        # Get the user map class names from the job XML.
        user_map_module = ma.const.JobsXmlData.get_str_data(ma.const.xml_map_module_name, job_id)
        user_map_class = ma.const.JobsXmlData.get_str_data(ma.const.xml_map_class_name, job_id)

        # Import the class and pull all the keys.  getattr is used
        # instead of eval() to instantiate the class by name.
        temp_import = __import__(user_map_module, fromlist=[user_map_class])
        temp_map = getattr(temp_import, user_map_class)(None, None, None)

        key_list = temp_map.get_list_of_all_keys()
        if key_list is not None:
            self.__log.info('Map has finite keys for reduce-id %d, job-id %d: Hence will be estimated', self.tip_info.task_id, self.tip_info.job_id)
        else:
            self.__log.info('Map don\'t have finite keys for reduce-id %d, job-id %d: Hence can\'t be estimated', self.tip_info.task_id, self.tip_info.job_id)

        return key_list


    def localizeTask(self):
        """Shuffle this reduce task's input to the local node.

        Considering that the reduce intermediate output is local to NTs
        this method will pick up map output, if it is a reduce level 0
        task, or a reduce x-1 level output if it is a reduce level x
        task.  Called when this task receives a broadcast from another
        nodetracker saying that the input for this task has been
        successfully stored locally for a map task or a reduce.

        1. talk to the master or the hdfs and ask which tasks you must pick your input from
        2.a. (master) pick up a map output file from every NT giving the reduce tasks hash value
        2.b. (hdfs) pick up the map and reduce tasks ids from which you must pick your input
          i. you broadcast that you want output of these tasks and who so ever has them responds
          ii. start a TCP client, you tcp the files from these nodes.

        Returns True when every input was retrieved, False otherwise.
        """

        # Create the input fetcher for shuffling data.
        self.input_fetcher = inputfetchernoquery.InputFetcherNoQuery(self.nodetracker.output_server, self.nodetracker, inputfetchernoquery.InputFetcherNoQuery.MRIMPROV_TYPE)

        # input_data entries look like (job_id, MAP/REDUCE, task_id),
        # e.g. [(1,M,1),(1,R,2),(1,R,3),(1,M,4)]
        remaining_inputs = self.input_fetcher.fetch_input(self.tip_info.input_data, self.tip_info.task_id, self.tip_info.job_id)

        if remaining_inputs is not None and len(remaining_inputs) > 0:
            self.__log.warning('Was not able to retrieve all reduce inputs: remaining %s', str(remaining_inputs))
            return False

        self.__log.info('Copied all reduce inputs for reduce-id %d, job-id %d: %s', self.tip_info.task_id, self.tip_info.job_id, str(self.tip_info.input_data))
        return True


    def launchTask(self):
        """Launch the TaskRunner child process and block until it ends.

        Also registers a periodic health-ping timer with the node
        tracker and records the task's start/end wall-clock times on
        tip_info.  Returns the child process's success flag.
        """

        self.reduce_threads = []

        arglist = [self.tip_info.job_id, self.tip_info.task_id, constants.REDUCE]
        self.pingtimer_id = self.nodetracker.timer.addTimer(self.task_ping_interval, self.nodetracker.healthPingFromTask, arglist)

        # Note the starting time.
        self.tip_info.start_time = time.time()

        # Start the process ... and wait till it ends.
        ret = self.process_runner.start_and_end_process()

        # Note the end time.
        self.tip_info.end_time = time.time()

        return ret


    def taskComplete(self):
        """Clean up this task's local input files and inform the node
        tracker, so that it can mark on the HDFS that the reduce output
        is ready as input to subsequent reduce tasks.
        """

        # Do the cleanup: delete the (now consumed) input files.
        input_dest_dir = ma.const.JobsXmlData.get_filepath_str_data(ma.const.xml_local_input_temp_dir, self.tip_info.job_id) + os.sep
        for input_data in self.tip_info.input_data:
            # input_data is (job_id, MAP/REDUCE, task_id) -- pick the
            # filename template matching the producing task's type.
            if input_data[1] == constants.MAP:
                input_filename = ma.const.JobsXmlData.get_str_data(ma.const.xml_map_output_filename, input_data[0], input_data[2])
            else:
                input_filename = ma.const.JobsXmlData.get_str_data(ma.const.xml_reduce_output_filename, input_data[0], input_data[2])

            input_filepath = os.path.join(input_dest_dir, input_filename)
            os.remove(input_filepath)

        self.__log.info('Reduce task %d of job %d completed processing', self.tip_info.task_id, self.tip_info.job_id)

        self.nodetracker.taskComplete(self.tip_info)


    def logOutput(self):
        """Open the reduce output file and write the single average
        value it holds to the 'ma.eststats' logger for estimation.

        The output file is expected to contain the repr of a dict with
        exactly one key mapping to a one-element value list; anything
        larger raises, anything else is logged as an error.  All
        failures are caught and logged (best effort).
        """

        try:
            self.output_dest_dir = ma.const.JobsXmlData.get_filepath_str_data(ma.const.xml_local_output_temp_dir,
                                                                    self.tip_info.job_id)
            output_filepath = self.output_dest_dir + os.sep + \
                ma.const.JobsXmlData.get_str_data(ma.const.xml_reduce_output_filename, self.tip_info.job_id, self.tip_info.task_id)

            with open(output_filepath, 'r') as fd:
                filecontents = fd.read()

            # NOTE(review): eval() executes arbitrary expressions from
            # the output file.  The file is produced by our own reducer
            # process, but ast.literal_eval would be safer -- flagged
            # for review rather than changed here.
            output_dict = eval(filecontents)
            # First value of the dictionary; raises (and is caught
            # below) if the dictionary is empty.
            first_value = output_dict[list(output_dict.keys())[0]]

            # Check that the dictionary holds exactly one value.
            if len(output_dict) > 1 or len(first_value) > 1:
                raise RuntimeError('Reduce\'s input has an incorrect dictionary for avg: %s' % str(output_dict))
            if len(output_dict) == 1 or len(first_value) == 1:
                # Emit the average value to the estimation statistics log.
                estimator_log = ma.log.get_logger('ma.eststats')
                est_tuple = first_value
                estimator_log.info(str(time.time()) + "," + str(est_tuple[0][0]) + "," + str(self.tip_info.struct_data) + "," + str(self.reduce_level))
            else:
                self.__log.error('Some problem with the estimator dictionary: %s', str(output_dict))
        except Exception as err_msg:
            self.__log.error('Error while writing the reduced estimate to the logger: %s', err_msg)


    def log_thresholded_keys(self):
        """Log every key that crossed the user threshold, together with
        its thresholding data, to the 'ma.thresholder' logger.
        """

        thresholder_log = ma.log.get_logger('ma.thresholder')

        # Write to the log all keys which crossed the threshold with
        # their times of thresholding.
        for thresholded_key in self.list_of_thresholded_keys:
            thresholder_log.info(str(thresholded_key[1]) + ',' + str(thresholded_key[0]) + ',' + str(thresholded_key[2]) + ',' + str(thresholded_key[3]))


    def run(self):
        """Thread entry point: the full flow of this ReduceTIP.

        Shuffle inputs, run the user reducer, harvest the pickled key
        bookkeeping lists, log estimates/thresholds and mark the task
        complete.  Raises IOError when the shuffle fails.
        """

        self.__log.info('Started ReduceTIP for %s', self.tip_info.returnTaskID())

        if not self.localizeTask():
            self.nodetracker.killTask(self.tip_info.job_id, constants.REDUCE, self.tip_info.task_id)
            self.__log.error('Couldn\'t shuffle all maps for Reduce %d', self.tip_info.task_id)
            raise IOError('The ReduceTIP couldn\'t shuffle data for Reduce %d' % self.tip_info.task_id)

        # Launch and finish the task.
        if self.launchTask():
            # Log the estimate, if the reduce is estimatable.
            if self.estimatable:
                self.logOutput()

            # Get the filename which holds the reduce compute misc. output.
            proc_temp_filename = ma.const.JobsXmlData.get_str_data(ma.const.xml_reduce_proc_temp_filename, self.tip_info.job_id, self.tip_info.task_id)

            # Get the pickled list data written by the child process.
            pickled_lists = ProcessRunner.pull_pickled_data(proc_temp_filename, self.tip_info.job_id)

            # If the map function had a finite number of keys, read the
            # two lists of not-seen keys and not-changed keys.
            if self.tip_info.list_of_all_keys is not None:
                self.list_of_not_seen_keys = pickled_lists[0]
                self.list_of_not_changed_keys = pickled_lists[1]

            # List of keys crossing the threshold during computation.
            self.list_of_thresholded_keys = pickled_lists[2]

            # Log the values which were thresholded after this
            # computation.  One doesn't need to worry whether the
            # reduce function was thresholdable, since if it wasn't an
            # empty list is returned.
            self.log_thresholded_keys()

            # Mark the task as complete.
            self.taskComplete()
        else:
            # TODO: Called failed task
            self.__log.error('Reduce task %d of job %d failed while processing', self.tip_info.task_id, self.tip_info.job_id)

        self.nodetracker.timer.removeTimer(self.pingtimer_id)
        